{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "train_data = datasets.CIFAR100('./dataset', transform=transforms.Compose([transforms.ToTensor()]), download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "\n",
    "in_channels = 10\n",
    "c_m = 4\n",
    "c_n = 3\n",
    "\n",
    "class DoubleAttention(nn.Module):\n",
    "    def __init__(self, in_channels, c_m, c_n,k =1 ):\n",
    "        super(DoubleAttentionLayer, self).__init__()\n",
    "\n",
    "        self.K           = k\n",
    "        self.c_m = c_m\n",
    "        self.c_n = c_n\n",
    "        self.softmax     = nn.Softmax()\n",
    "        self.in_channels = in_channels\n",
    "\n",
    "        self.convA = nn.Conv2d(in_channels, c_m, 1)\n",
    "        self.convB = nn.Conv2d(in_channels, c_n, 1)\n",
    "        self.convV = nn.Conv2d(in_channels, c_n, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        b, c, h, w = x.size()\n",
    "\n",
    "        assert c == self.in_channels,'input channel not equal!'\n",
    "\n",
    "        A = self.convA(x)\n",
    "        B = self.convB(x)\n",
    "        V = self.convV(x)\n",
    "\n",
    "        batch = int(b/self.K)\n",
    "\n",
    "        tmpA = A.view( batch, self.K, self.c_m, h*w ).permute(0,2,1,3).view( batch, self.c_m, self.K*h*w )\n",
    "        tmpB = B.view( batch, self.K, self.c_n, h*w ).permute(0,2,1,3).view( batch*self.c_n, self.K*h*w )\n",
    "        tmpV = V.view( batch, self.K, self.c_n, h*w ).permute(0,1,3,2).contiguous().view( int(b*h*w), self.c_n )\n",
    "\n",
    "        softmaxB = self.softmax(tmpB).view( batch, self.c_n, self.K*h*w ).permute( 0, 2, 1)  #batch, self.K*h*w, self.c_n\n",
    "        softmaxV = self.softmax(tmpV).view( batch, self.K*h*w, self.c_n ).permute( 0, 2, 1)  #batch, self.c_n  , self.K*h*w\n",
    "\n",
    "        tmpG     = tmpA.matmul( softmaxB )      #batch, self.c_m, self.c_n\n",
    "        tmpZ     = tmpG.matmul( softmaxV ) #batch, self.c_m, self.K*h*w\n",
    "        tmpZ     = tmpZ.view(batch, self.c_m, self.K,h*w).permute( 0, 2, 1,3).view( int(b), self.c_m, h, w )\n",
    "\n",
    "        return tmpZ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def conv3x3(in_planes, out_planes, stride=1):\n",
    "    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "\n",
    "def conv1x1(in_planes, out_planes, stride=1):\n",
    "    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n",
    "\n",
    "\n",
    "class BasicBlock(nn.Module):\n",
    "    expansion = 1\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None, *args, **kwargs):\n",
    "        super(BasicBlock, self).__init__()\n",
    "        self.conv1 = conv3x3(inplanes, planes, stride)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.conv2 = conv3x3(planes, planes)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        #self.attention = DoubleAttention(planes)\n",
    "        \n",
    "        if inplanes != planes:\n",
    "            self.downsample = nn.Sequential(conv1x1(inplanes, planes, stride), nn.BatchNorm2d(planes))\n",
    "        else:\n",
    "            self.downsample = lambda x: x\n",
    "        self.stride = stride\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = self.downsample(x)\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "        #out = self.attention(out)\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "\n",
    "        out += residual\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class Bottleneck(nn.Module):\n",
    "    expansion = 4\n",
    "\n",
    "    def __init__(self, inplanes, planes, stride=1, downsample=None):\n",
    "        super(Bottleneck, self).__init__()\n",
    "        self.conv1 = conv1x1(inplanes, planes)\n",
    "        self.bn1 = nn.BatchNorm2d(planes)\n",
    "        self.conv2 = conv3x3(planes, planes, stride)\n",
    "        self.bn2 = nn.BatchNorm2d(planes)\n",
    "        self.conv3 = conv1x1(planes, planes * self.expansion)\n",
    "        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.downsample = downsample\n",
    "        self.stride = stride\n",
    "        \n",
    "\n",
    "    def forward(self, x):\n",
    "        identity = x\n",
    "\n",
    "        out = self.conv1(x)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.bn2(out)\n",
    "        out = self.relu(out)\n",
    "\n",
    "        out = self.conv3(out)\n",
    "        out = self.bn3(out)\n",
    "\n",
    "        if self.downsample is not None:\n",
    "            identity = self.downsample(x)\n",
    "\n",
    "        out += identity\n",
    "        out = self.relu(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "class ResNet(nn.Module):\n",
    "\n",
    "    def __init__(self, block, layers, num_classes=100, zero_init_residual=False):\n",
    "        super(ResNet, self).__init__()\n",
    "        self.inplanes = 64\n",
    "        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
    "        self.layer1 = self._make_layer(block, 64, layers[0])\n",
    "        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n",
    "        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n",
    "        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        self.fc = nn.Linear(512 * block.expansion, num_classes)\n",
    "\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
    "            elif isinstance(m, nn.BatchNorm2d):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "\n",
    "        if zero_init_residual:\n",
    "            for m in self.modules():\n",
    "                if isinstance(m, Bottleneck):\n",
    "                    nn.init.constant_(m.bn3.weight, 0)\n",
    "                elif isinstance(m, BasicBlock):\n",
    "                    nn.init.constant_(m.bn2.weight, 0)\n",
    "\n",
    "    def _make_layer(self, block, planes, blocks, stride=1):\n",
    "        downsample = stride != 1\n",
    "\n",
    "        layers = []\n",
    "        layers.append(block(self.inplanes, planes, stride, downsample))\n",
    "        self.inplanes = planes * block.expansion\n",
    "        for _ in range(1, blocks):\n",
    "            layers.append(block(self.inplanes, planes, 1))\n",
    "\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = x\n",
    "        out = self.conv1(out)\n",
    "        out = self.bn1(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.maxpool(out)\n",
    "\n",
    "        out = self.layer1(out)\n",
    "        out = self.layer2(out)\n",
    "        out = self.layer3(out)\n",
    "        out = self.layer4(out)\n",
    "\n",
    "        out = self.avgpool(out)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "def resnet18():\n",
    "    return ResNet(BasicBlock, [2, 2, 2, 2])\n",
    "\n",
    "def resnet34():\n",
    "    return ResNet(BasicBlock, [3, 4, 6, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "0 : accuracy= 15.92\n",
      "1 : accuracy= 22.07\n",
      "2 : accuracy= 19.03\n",
      "3 : accuracy= 30.03\n",
      "4 : accuracy= 27.13\n",
      "5 : accuracy= 31.44\n",
      "6 : accuracy= 27.91\n",
      "7 : accuracy= 25.03\n",
      "8 : accuracy= 30.42\n",
      "9 : accuracy= 33.16\n",
      "10 : accuracy= 33.86\n",
      "11 : accuracy= 32.34\n",
      "12 : accuracy= 34.62\n",
      "13 : accuracy= 35.38\n",
      "14 : accuracy= 36.89\n",
      "15 : "
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[1;32mIn [66]\u001b[0m, in \u001b[0;36m<cell line: 53>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     54\u001b[0m \u001b[38;5;28mprint\u001b[39m(epoch,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m: \u001b[39m\u001b[38;5;124m\"\u001b[39m,end\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     55\u001b[0m train(epoch, model, device, train_loader, optimizer)\n\u001b[1;32m---> 56\u001b[0m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n",
      "Input \u001b[1;32mIn [66]\u001b[0m, in \u001b[0;36mtest\u001b[1;34m(model, device, test_loader)\u001b[0m\n\u001b[0;32m     31\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch_idx, (data, target) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(test_loader):\n\u001b[0;32m     32\u001b[0m     data, target \u001b[38;5;241m=\u001b[39m data\u001b[38;5;241m.\u001b[39mto(device), target\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m---> 33\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     34\u001b[0m     pred \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39margmax(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m     35\u001b[0m     correct \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m pred\u001b[38;5;241m.\u001b[39meq(target\u001b[38;5;241m.\u001b[39mview_as(pred))\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;241m.\u001b[39mitem()\n",
      "File \u001b[1;32m~\\AppData\\Local\\conda\\conda\\envs\\torchgpu\\lib\\site-packages\\torch\\nn\\modules\\module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[1;32mIn [65]\u001b[0m, in \u001b[0;36mResNet.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m    129\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(out)\n\u001b[0;32m    130\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmaxpool(out)\n\u001b[1;32m--> 132\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayer1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    133\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer2(out)\n\u001b[0;32m    134\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer3(out)\n",
      "File \u001b[1;32m~\\AppData\\Local\\conda\\conda\\envs\\torchgpu\\lib\\site-packages\\torch\\nn\\modules\\module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[1;32m~\\AppData\\Local\\conda\\conda\\envs\\torchgpu\\lib\\site-packages\\torch\\nn\\modules\\container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m    140\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 141\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m    142\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "File \u001b[1;32m~\\AppData\\Local\\conda\\conda\\envs\\torchgpu\\lib\\site-packages\\torch\\nn\\modules\\module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[1;32mIn [65]\u001b[0m, in \u001b[0;36mBasicBlock.forward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m     35\u001b[0m \u001b[38;5;66;03m#out = self.attention(out)\u001b[39;00m\n\u001b[0;32m     36\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconv2(out)\n\u001b[1;32m---> 37\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbn2\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     39\u001b[0m out \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m residual\n\u001b[0;32m     40\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrelu(out)\n",
      "File \u001b[1;32m~\\AppData\\Local\\conda\\conda\\envs\\torchgpu\\lib\\site-packages\\torch\\nn\\modules\\module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m forward_call(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28minput\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[0;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[1;32m~\\AppData\\Local\\conda\\conda\\envs\\torchgpu\\lib\\site-packages\\torch\\nn\\modules\\batchnorm.py:168\u001b[0m, in \u001b[0;36m_BatchNorm.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m    161\u001b[0m     bn_training \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_mean \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrunning_var \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m    163\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m    164\u001b[0m \u001b[38;5;124;03mBuffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be\u001b[39;00m\n\u001b[0;32m    165\u001b[0m \u001b[38;5;124;03mpassed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are\u001b[39;00m\n\u001b[0;32m    166\u001b[0m \u001b[38;5;124;03mused for normalization (i.e. in eval mode when buffers are not None).\u001b[39;00m\n\u001b[0;32m    167\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m--> 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_norm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    169\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m    170\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;66;43;03m# If buffers are not to be tracked, ensure that they won't be updated\u001b[39;49;00m\n\u001b[0;32m    171\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_mean\u001b[49m\n\u001b[0;32m    172\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack_running_stats\u001b[49m\n\u001b[0;32m    173\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m    174\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrunning_var\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack_running_stats\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m    175\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    176\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    177\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbn_training\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    178\u001b[0m \u001b[43m    \u001b[49m\u001b[43mexponential_average_factor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    179\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    180\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\AppData\\Local\\conda\\conda\\envs\\torchgpu\\lib\\site-packages\\torch\\nn\\functional.py:2421\u001b[0m, in \u001b[0;36mbatch_norm\u001b[1;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[0;32m   2418\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training:\n\u001b[0;32m   2419\u001b[0m     _verify_batch_size(\u001b[38;5;28minput\u001b[39m\u001b[38;5;241m.\u001b[39msize())\n\u001b[1;32m-> 2421\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_norm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   2422\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrunning_mean\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrunning_var\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmomentum\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackends\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcudnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menabled\u001b[49m\n\u001b[0;32m   2423\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import torchvision\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.autograd import Variable\n",
    "\n",
    "\n",
    "\n",
    "#GPU = torch.cuda.is_available()\n",
    "\n",
    "\n",
    "def train(epoch, model, device, train_loader, optimizer):\n",
    "    model.train()\n",
    "\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "    \n",
    "        loss = F.cross_entropy(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "def test(model, device, test_loader):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, target) in enumerate(test_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            pred = output.argmax(dim=1, keepdim=True)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "    \n",
    "    print('accuracy=',100. * correct/len(test_loader.dataset))\n",
    "\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda\")\n",
    "model = resnet18().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.03)\n",
    "\n",
    "train_data = torchvision.datasets.CIFAR100('./dataset',\n",
    "                    transform=transforms.Compose([transforms.ToTensor()]), download=True)\n",
    "train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)\n",
    "test_data = torchvision.datasets.CIFAR100('./dataset',\n",
    "                    transform=transforms.Compose([transforms.ToTensor()]),  train=False, download=True)\n",
    "test_loader = torch.utils.data.DataLoader(test_data)\n",
    "\n",
    "\n",
    "for epoch in range(100):\n",
    "    print(epoch,\": \",end=\"\")\n",
    "    train(epoch, model, device, train_loader, optimizer)\n",
    "    test(model, device, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "1.11.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "print(torch.cuda.is_available())\n",
    "print (torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
